#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2016-2018, Niklas Hauser
# Copyright (c) 2023, Christopher Durand
#
# This file is part of the modm project.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# -----------------------------------------------------------------------------

from collections import defaultdict
import re

# Module-wide state shared by the functions and the Instance class below.
# prepare() fills in the target-dependent entries, Instance.build() adds the
# per-instance entries, and the whole dict is handed to the template engine
# via `env.substitutions`.
props = {}

class Instance(Module):
    """Submodule describing a single hardware timer instance.

    Each instance is classified as "advanced", "basic" or "general_purpose"
    from its number; the classification selects which pair of templates is
    rendered in build().
    """

    def __init__(self, driver, instance):
        """Store the driver entry and derive the timer type from the number."""
        self.driver = driver
        self.instance = int(instance)
        # Populated by validate(): maps IRQ name -> list of interrupt flags.
        self.vectors = None
        if self.instance in (1, 8, 20):
            self.type = "advanced"
        elif self.instance in (6, 7, 18):
            self.type = "basic"
        else:
            self.type = "general_purpose"

    def init(self, module):
        """Name the submodule after the timer number."""
        number = self.instance
        module.name = str(number)
        module.description = "Instance {}".format(number)

    def prepare(self, module, options):
        """Always available once the parent module created this instance."""
        return True

    def validate(self, env):
        """Resolve the IRQ vectors of this timer and check their count.

        Raises ValidateException when no vector was found for the timer, or
        when a timer that is not "advanced" has more than one vector.
        """
        # Fragment of the IRQ name -> interrupt flags served by that IRQ.
        flags_by_fragment = {
            "_UP": ["Update"],
            "_TRG": ["Trigger"],
            "_BRK": ["Break"],
            "_COM": ["COM"],
            "_CC": ["CaptureCompare1", "CaptureCompare2",
                     "CaptureCompare3", "CaptureCompare4"],
        }
        irq_names = props["timer_vectors"][self.instance]
        self.vectors = {
            irq: [flag
                  for fragment, flags in flags_by_fragment.items()
                  if fragment in irq
                  for flag in flags]
            for irq in irq_names
        }

        if not self.vectors:
            raise ValidateException("No interrupts found for Timer{}! Possible IRQs are {}"
                                    .format(self.instance, props["timer_vectors"]))
        # Only advanced timers may spread their interrupts over several IRQs.
        if self.type != "advanced" and len(self.vectors) != 1:
            raise ValidateException("Timer{} is only allowed to have one IRQ! Found {}"
                                    .format(self.instance, self.vectors))

    def build(self, env):
        """Render the header and source templates for this timer instance."""
        # Per-instance substitutions live in the shared dict next to the
        # target-wide ones, so they are overwritten for every instance.
        props.update(id=self.instance, vectors=self.vectors)
        props["types"].add(self.type)

        env.substitutions = props
        env.outbasepath = "modm/src/modm/platform/timer"
        for extension in ("hpp", "cpp"):
            env.template("{}.{}.in".format(self.type, extension),
                         "timer_{}.{}".format(self.instance, extension))


def init(module):
    """Set the name and description of the timer parent module."""
    module.description = "Timers (TIM)"
    module.name = ":platform:timer"

def prepare(module, options):
    """Enable the module for targets with an STM32 timer and collect its data.

    Registers the module dependencies, adds one Instance submodule per timer
    instance and fills the shared `props` dict with the target identifier,
    the (initially empty) set of generated timer types, the
    "advanced_extended" flag and the interrupt vectors of every timer.

    Returns False when the target has no "tim:stm32*" driver, True otherwise.
    """
    device = options[":target"]
    if not device.has_driver("tim:stm32*"):
        return False

    module.depends(
        ":architecture:register",
        ":cmsis:device",
        ":platform:gpio",
        ":platform:rcc")

    instances = []
    for driver in device.get_all_drivers("tim"):
        for instance in driver["instance"]:
            module.add_submodule(Instance(driver, instance))
            instances.append(int(instance))

    # `props` is only mutated, never rebound, so no `global` statement is
    # needed; `device` from above is reused instead of being looked up again.
    target = device.identifier
    props["target"] = target
    props["types"] = set()
    # extended advanced control timer with 6 channels and trigger out 2
    legacy_families = ["f0", "f1", "f2", "f3", "f4", "l0", "l1"]
    props["advanced_extended"] = (
        (target.family not in legacy_families) or
        ((target.family == "f3") and (target.name not in ["73", "78"])))

    tim_vectors = defaultdict(list)
    props["timer_vectors"] = tim_vectors
    vectors = [v["name"] for v in device.get_driver("core")["vector"] if "TIM" in v["name"]]
    for instance in instances:
        timstr = "TIM{}".format(instance)
        for vector in vectors:
            # Match "TIMx" only as a complete underscore-delimited token so
            # that e.g. TIM1 does not pick up the vectors of TIM10.
            if (vector == timstr or ("_" + timstr + "_") in vector or
                    vector.startswith(timstr + "_") or vector.endswith("_" + timstr)):
                tim_vectors[instance].append(vector)

    return True

def build(env):
    """Collect the timer channel signals and render the shared base headers.

    Stores in props["signals"] a (signal name, channel digit) tuple for every
    GPIO signal named Ch1..Ch6 with an optional trailing "n", then renders
    one "<type>_base.hpp.in" template per required timer type.
    """
    env.substitutions = props
    env.outbasepath = "modm/src/modm/platform/timer"

    pattern = re.compile(r"^(Ch([1-6])n{0,1})$")
    all_signals = env.query(":platform:gpio:all_signals")
    # Match every signal once and keep only the hits.
    matches = (pattern.match(signal) for signal in all_signals)
    props["signals"] = [match.groups() for match in matches if match]

    # Only generate the base types for timers that were generated.
    # An advanced timer also needs the general purpose base, which in turn
    # needs the basic one (presumably because the bases build on each other).
    types = props["types"]
    if "advanced" in types:
        types.add("general_purpose")
    if "general_purpose" in types:
        types.add("basic")
    # Sorted so the generation order does not depend on set iteration order.
    for ttype in sorted(types):
        env.template("{}_base.hpp.in".format(ttype))

